Introduction to Computer Vision: Plant Seedlings Classification¶

Problem Statement¶

Context¶

In recent times, agriculture has been in urgent need of modernization, since checking whether plants are growing correctly still requires extensive manual work. Despite several advances in agricultural technology, workers in the industry still need to sort and recognize different plants and weeds by hand, which costs considerable time and effort in the long term. This trillion-dollar industry is ripe for technological innovations that cut down on the requirement for manual labor, and this is where Artificial Intelligence can genuinely benefit workers in the field: AI and Deep Learning can greatly shorten the time and energy required to identify plant seedlings. Doing so more efficiently, and potentially more accurately, than experienced manual labor could lead to better crop yields, free up human involvement for higher-order agricultural decision-making, and in the long term result in more sustainable environmental practices in agriculture as well.

Objective¶

The aim of this project is to build a Convolutional Neural Network (CNN) to classify plant seedlings into their respective categories.

Data Dictionary¶

The Aarhus University Signal Processing group, in collaboration with the University of Southern Denmark, has recently released a dataset containing images of unique plants belonging to 12 different species.

  • The dataset can be downloaded from Olympus.

  • The data file names are:

    • images.npy
    • Labels.csv
  • Due to the large volume of data, the images were converted to the images.npy file and the labels were placed in Labels.csv, so that you can work on the data/project seamlessly without having to worry about the high data volume.

  • The goal of the project is to create a classifier capable of determining a plant's species from an image.

List of Species

  • Black-grass
  • Charlock
  • Cleavers
  • Common Chickweed
  • Common Wheat
  • Fat Hen
  • Loose Silky-bent
  • Maize
  • Scentless Mayweed
  • Shepherds Purse
  • Small-flowered Cranesbill
  • Sugar beet

Note: Please use GPU runtime on Google Colab to execute the code faster.¶

Importing necessary libraries¶

In [2]:
# libraries for numerical analysis, data manipulation and plotting
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# libraries for image
import cv2
from google.colab.patches import cv2_imshow

# libraries for data preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import class_weight
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# libraries for DeepLearning and CNN
import tensorflow as tf
from tensorflow.keras.backend import clear_session
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Flatten, Dropout, BatchNormalization, SpatialDropout2D, Activation, Input
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# warning
import warnings
warnings.filterwarnings('ignore')

Setting seeds¶

In [4]:
random_state = 42
# reset session
def reset_session():
  clear_session()
  tf.compat.v1.reset_default_graph()
  # reset seeds
  random.seed(random_state)
  np.random.seed(random_state)
  tf.random.set_seed(random_state)

reset_session()

Loading the dataset¶

In [5]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [6]:
label_path = "/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/Labels.csv"
img_path = "/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/images.npy"
In [7]:
X = np.load(img_path)
y = pd.read_csv(label_path)

Data Overview¶

In [8]:
# checking the data is loaded and displaying the first image
cv2_imshow(X[0])
In [9]:
# checking the data is loaded for the label
y.head()
Out[9]:
Label
0 Small-flowered Cranesbill
1 Small-flowered Cranesbill
2 Small-flowered Cranesbill
3 Small-flowered Cranesbill
4 Small-flowered Cranesbill

Understand the shape of the dataset¶

In [10]:
X.shape, y.shape
Out[10]:
((4750, 128, 128, 3), (4750, 1))
In [11]:
X.min(), X.max()
Out[11]:
(0, 255)
In [12]:
y.nunique()
Out[12]:
Label    12
dtype: int64

In [13]:
# checking the channels of the image
plt.imshow(X[0])
plt.axis('off')
plt.show()

Observations¶

  • Images are stored in a 4-dimensional NumPy array.
  • The first dimension is the number of images; we have 4,750 images in total.
  • The second dimension is the number of pixels on the x-axis and the third dimension is the number of pixels on the y-axis; each image is 128*128 pixels.
  • The fourth dimension is the number of channels; we have 3 channels.
  • There are 12 unique labels in the given dataset.
  • We displayed the image using the cv2 library and also plotted it using matplotlib. The matplotlib plot appears blue-tinted, while the same image displayed with cv2 looks normal. This shows the images are stored in BGR format instead of RGB; BGR is cv2's default format, hence cv2 renders the images correctly.
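The BGR-vs-RGB effect described above can be illustrated without any image files. This is a minimal sketch using a hypothetical 1x1 "image": reversing the channel axis is the same reordering that `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)` performs.

```python
import numpy as np

# A 1x1 "image" whose single pixel is pure blue in BGR order: (B, G, R) = (255, 0, 0)
bgr_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)

# Reversing the last (channel) axis turns BGR into RGB
rgb_pixel = bgr_pixel[..., ::-1]

# matplotlib interprets the last axis as (R, G, B), so the raw BGR array would
# render as red, while the converted array renders as the intended blue
print(rgb_pixel[0, 0].tolist())  # [0, 0, 255]
```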

Exploratory Data Analysis¶

  • EDA is an important part of any project involving data.
  • It is important to investigate and understand the data better before building a model with it.
  • A few questions have been mentioned below which will help you understand the data better.
  • A thorough analysis of the data, in addition to the questions mentioned below, should be done.
  1. How are these different category plant images different from each other?
  2. Is the provided dataset imbalanced? (Check using bar plots)
In [14]:
# Unique labels
num_classes = 12
categories = np.unique(y)
print('Unique Categories:',categories)
Unique Categories: ['Black-grass' 'Charlock' 'Cleavers' 'Common Chickweed' 'Common wheat'
 'Fat Hen' 'Loose Silky-bent' 'Maize' 'Scentless Mayweed'
 'Shepherds Purse' 'Small-flowered Cranesbill' 'Sugar beet']
In [15]:
# Print random image from each categories
fig, axes = plt.subplots(3, 4, figsize = (8, 8))
for cat, ax in zip(categories, axes.flatten()):
  random.seed(random_state)
  rand_of_each_cat = random.choice(y[y['Label'] == cat].index)
  img = cv2.cvtColor(X[rand_of_each_cat], cv2.COLOR_BGR2RGB)
  ax.imshow(img)
  ax.set_title(cat)
  ax.axis('off')
plt.tight_layout()
plt.show()
In [16]:
# print random 3 images from each category
plt.figure(figsize=(8,20))
num_images_per_category = 3
for index,cat in enumerate(categories):
  random.seed(random_state)
  rand_choices = random.choices(y[y['Label'] == cat].index, k=num_images_per_category)
  for i, rand_choice in enumerate(rand_choices):
        ax = plt.subplot(
            len(categories),
            num_images_per_category,
            index * num_images_per_category + i + 1)
        img = cv2.cvtColor(X[rand_choice], cv2.COLOR_BGR2RGB)
        ax.imshow(img)
        ax.set_title(cat)
        ax.axis('off')
plt.tight_layout()
plt.show()
In [17]:
y['Label'].value_counts(normalize=True).plot.bar()
plt.title('Class Distribution')
plt.xlabel('Class')
plt.ylabel('Percentage')
plt.show()

Observations¶

  • We checked the shape of the images and also plotted them. All images have the same size of 128*128 pixels.
  • We have 12 categories of images.
  • We first plotted a random image from each category and noticed that the seedlings vary in leaf shape and size. Their backgrounds also vary, and some images contain multiple seedlings of the same type.
  • We also plotted 3 random images from each category to understand the visual variance within a category. Leaf shape and size vary among images of the same category, as do background, contrast, and blurriness.
  • The dataset is imbalanced. Loose Silky-bent, Common Chickweed, and Scentless Mayweed are the top 3 categories, while some categories like Common Wheat and Maize have as little as just over 4% representation in the dataset.
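The per-class shares behind the bar plot come from `value_counts(normalize=True)`. A small sketch with a hypothetical miniature label frame (standing in for the project's `y`) shows how the shares are computed:

```python
import pandas as pd

# Hypothetical miniature label frame with an imbalanced class distribution
y_demo = pd.DataFrame({'Label': ['Loose Silky-bent'] * 6 +
                                ['Maize'] * 2 +
                                ['Common wheat'] * 2})

# value_counts(normalize=True) yields each class's share of the dataset,
# which is exactly what the bar plot above visualizes
shares = y_demo['Label'].value_counts(normalize=True)
print(round(shares['Loose Silky-bent'], 2))  # 0.6
print(round(shares['Maize'], 2))             # 0.2
```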

Data Pre-Processing¶

Convert the BGR images to RGB images.¶

In [18]:
plt.figure(figsize=(8,8))
for i in range(5):
  plt.subplot(1,5,i+1)
  plt.imshow(X[i])
  plt.axis('off')
plt.show()
In [19]:
X_RGB = np.array([cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in X])
In [20]:
plt.figure(figsize=(8,8))
for i in range(5):
  plt.subplot(1,5,i+1)
  plt.imshow(X_RGB[i])
  plt.axis('off')
plt.show()
  • After converting the images to RGB, matplotlib can plot the images without the blue tint.

Resize the images¶

As the images are large, training on them may be computationally expensive; therefore, it is preferable to reduce the image size from 128*128 to 64*64 pixels.

In [21]:
X_resized = np.array([cv2.resize(img, (64, 64)) for img in X_RGB])
# plotting first 5 resized images
plt.figure(figsize=(8,8))
for i in range(5):
  plt.subplot(1,5,i+1)
  plt.imshow(X_resized[i])
  plt.axis('off')
plt.show()

Data Preparation for Modeling¶

  • Before we proceed to build a model, we need to split the data into train and test sets. We will use 10% of the train data for validation.
  • We will encode the categorical target and scale the pixel values.
  • We will increase class weights for the minority classes.
  • We will build a model using the train data and then check its performance on the test data.
In [22]:
# Train, Test split
X_train, X_test, y_train, y_test = train_test_split(X_resized, y, test_size=0.1, random_state = random_state, stratify = y)
In [23]:
X_train.shape, X_test.shape, y_train.shape, y_test.shape
Out[23]:
((4275, 64, 64, 3), (475, 64, 64, 3), (4275, 1), (475, 1))

Encode the target labels¶

In [24]:
lb = LabelBinarizer()
y_train_encoded = lb.fit_transform(y_train)
y_test_encoded = lb.transform(y_test)
In [25]:
y_train_encoded.shape, y_test_encoded.shape
Out[25]:
((4275, 12), (475, 12))

Data Normalization¶

In [26]:
X_train_norm = X_train / 255.0
X_test_norm = X_test / 255.0

Class weight¶

In [27]:
classes = np.unique(y_train)
class_weights = class_weight.compute_class_weight('balanced', classes=classes, y=y_train['Label'])
class_weight_dict = dict(enumerate(class_weights))
print("Class Weights:", class_weight_dict)
Class Weights: {0: 1.5031645569620253, 1: 1.014957264957265, 2: 1.380813953488372, 3: 0.6477272727272727, 4: 1.7902010050251256, 5: 0.8343091334894613, 6: 0.6048387096774194, 7: 1.7902010050251256, 8: 0.7677801724137931, 9: 1.7127403846153846, 10: 0.7987668161434978, 11: 1.026657060518732}
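The `'balanced'` heuristic behind these weights is `weight_c = n_samples / (n_classes * n_c)`, so rarer classes receive proportionally larger weights. A minimal sketch with hypothetical class counts:

```python
import numpy as np

# Hypothetical per-class sample counts: 12 samples across 3 classes
counts = np.array([6, 4, 2])
n_samples, n_classes = counts.sum(), len(counts)

# sklearn's 'balanced' formula: n_samples / (n_classes * n_c)
weights = n_samples / (n_classes * counts)
print(np.round(weights, 2).tolist())  # [0.67, 1.0, 2.0]
```

The rarest class (2 samples) gets the largest weight, which is why Common wheat and Maize (weight ~1.79 above) are weighted most heavily in training.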

Model Building¶

  • We will first build a CNN model with the balanced class weights computed above.
  • We will use batch normalization and dropout, which help with internal covariate shift and regularization.
  • We will tune the model's hyperparameters to improve performance.
  • We will build a model with data augmentation.
  • If we are not happy with the performance, we may try a pre-trained model and fine-tune it.

CNN model with class_weights and without data augmentation¶

Model architecture building and training¶

In [ ]:
# Early stopping: stop training if validation loss doesn't decrease for 10 epochs
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10, restore_best_weights=True)
# Model checkpoint: Save the best model with highest accuracy in validation data, so it can be loaded to continue training from the saved state
mc = ModelCheckpoint('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
In [ ]:
reset_session()
model = Sequential([
    # Input layer
    Input(shape=(64, 64, 3), name='Input_Layer'),
    # Layer 1
    Conv2D(128, (3, 3), activation=None, padding = 'same', name = 'Conv1_128'),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(2, 2),
    # Layer 2
    Conv2D(64, (3, 3), activation=None, padding = 'same', name = 'Conv2_64'),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(2, 2),

    # Layer 3
    Conv2D(32, (3, 3), activation=None, padding = 'same', name = 'Conv3_32'),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(2, 2),

    # Layer 4
    Conv2D(32, (3, 3), activation=None, padding = 'same', name = 'Conv4_32'),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(2, 2),

    # FC layer 1
    Flatten(),
    Dense(128, activation='relu', name = 'FC1_128'),
    Dropout(0.25),

    # FC layer 2
    Dense(64, activation='relu', name = 'FC2_64'),
    Dropout(0.25),

    # FC layer 3
    Dense(32, activation='relu', name = 'FC3_32'),
    Dropout(0.25),

    # Output layer
    Dense(12, activation='softmax', name = 'Output_12')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ Conv1_128 (Conv2D)                   │ (None, 64, 64, 128)         │           3,584 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ batch_normalization                  │ (None, 64, 64, 128)         │             512 │
│ (BatchNormalization)                 │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ activation (Activation)              │ (None, 64, 64, 128)         │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ max_pooling2d (MaxPooling2D)         │ (None, 32, 32, 128)         │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Conv2_64 (Conv2D)                    │ (None, 32, 32, 64)          │          73,792 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ batch_normalization_1                │ (None, 32, 32, 64)          │             256 │
│ (BatchNormalization)                 │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ activation_1 (Activation)            │ (None, 32, 32, 64)          │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ max_pooling2d_1 (MaxPooling2D)       │ (None, 16, 16, 64)          │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Conv3_32 (Conv2D)                    │ (None, 16, 16, 32)          │          18,464 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ batch_normalization_2                │ (None, 16, 16, 32)          │             128 │
│ (BatchNormalization)                 │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ activation_2 (Activation)            │ (None, 16, 16, 32)          │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ max_pooling2d_2 (MaxPooling2D)       │ (None, 8, 8, 32)            │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Conv4_32 (Conv2D)                    │ (None, 8, 8, 32)            │           9,248 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ batch_normalization_3                │ (None, 8, 8, 32)            │             128 │
│ (BatchNormalization)                 │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ activation_3 (Activation)            │ (None, 8, 8, 32)            │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ max_pooling2d_3 (MaxPooling2D)       │ (None, 4, 4, 32)            │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ flatten (Flatten)                    │ (None, 512)                 │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ FC1_128 (Dense)                      │ (None, 128)                 │          65,664 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout (Dropout)                    │ (None, 128)                 │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ FC2_64 (Dense)                       │ (None, 64)                  │           8,256 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout_1 (Dropout)                  │ (None, 64)                  │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ FC3_32 (Dense)                       │ (None, 32)                  │           2,080 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout_2 (Dropout)                  │ (None, 32)                  │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ Output_12 (Dense)                    │ (None, 12)                  │             396 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 182,508 (712.92 KB)
 Trainable params: 181,996 (710.92 KB)
 Non-trainable params: 512 (2.00 KB)
  • Our first model has 4 convolutional layers and 3 fully connected layers.
  • It has 182,508 parameters in total, of which 181,996 are trainable.
In [ ]:
history = model.fit(X_train_norm, y_train_encoded, batch_size=32, epochs=50, validation_split=0.1, class_weight=class_weight_dict, callbacks=[es, mc])
Epoch 1/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step - accuracy: 0.0806 - loss: 2.6187
Epoch 1: val_accuracy improved from -inf to 0.12383, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 21s 67ms/step - accuracy: 0.0808 - loss: 2.6178 - val_accuracy: 0.1238 - val_loss: 2.4849
Epoch 2/50
118/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.1713 - loss: 2.3358
Epoch 2: val_accuracy did not improve from 0.12383
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.1725 - loss: 2.3308 - val_accuracy: 0.1238 - val_loss: 3.8904
Epoch 3/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.3208 - loss: 1.8846
Epoch 3: val_accuracy did not improve from 0.12383
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 13ms/step - accuracy: 0.3211 - loss: 1.8830 - val_accuracy: 0.1238 - val_loss: 6.3170
Epoch 4/50
117/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.4132 - loss: 1.6078
Epoch 4: val_accuracy improved from 0.12383 to 0.14019, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 13ms/step - accuracy: 0.4142 - loss: 1.6048 - val_accuracy: 0.1402 - val_loss: 3.5133
Epoch 5/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.4953 - loss: 1.4219
Epoch 5: val_accuracy improved from 0.14019 to 0.30374, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - accuracy: 0.4957 - loss: 1.4202 - val_accuracy: 0.3037 - val_loss: 2.1526
Epoch 6/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.5223 - loss: 1.3357
Epoch 6: val_accuracy improved from 0.30374 to 0.35514, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - accuracy: 0.5229 - loss: 1.3331 - val_accuracy: 0.3551 - val_loss: 2.0349
Epoch 7/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.5706 - loss: 1.2033
Epoch 7: val_accuracy improved from 0.35514 to 0.63551, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 12ms/step - accuracy: 0.5717 - loss: 1.2005 - val_accuracy: 0.6355 - val_loss: 1.1135
Epoch 8/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.6377 - loss: 1.0656
Epoch 8: val_accuracy did not improve from 0.63551
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - accuracy: 0.6379 - loss: 1.0645 - val_accuracy: 0.1519 - val_loss: 6.2629
Epoch 9/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.6656 - loss: 0.9464
Epoch 9: val_accuracy improved from 0.63551 to 0.74766, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - accuracy: 0.6656 - loss: 0.9463 - val_accuracy: 0.7477 - val_loss: 0.8064
Epoch 10/50
118/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.7042 - loss: 0.8665
Epoch 10: val_accuracy did not improve from 0.74766
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - accuracy: 0.7044 - loss: 0.8654 - val_accuracy: 0.4393 - val_loss: 2.1030
Epoch 11/50
119/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.7330 - loss: 0.7736
Epoch 11: val_accuracy did not improve from 0.74766
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 13ms/step - accuracy: 0.7330 - loss: 0.7734 - val_accuracy: 0.6355 - val_loss: 1.2828
Epoch 12/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7481 - loss: 0.7309
Epoch 12: val_accuracy did not improve from 0.74766
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.7483 - loss: 0.7304 - val_accuracy: 0.7033 - val_loss: 0.8096
Epoch 13/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7731 - loss: 0.6545
Epoch 13: val_accuracy did not improve from 0.74766
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 12ms/step - accuracy: 0.7732 - loss: 0.6540 - val_accuracy: 0.7360 - val_loss: 0.8313
Epoch 14/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8024 - loss: 0.5923
Epoch 14: val_accuracy did not improve from 0.74766
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.8025 - loss: 0.5923 - val_accuracy: 0.5864 - val_loss: 1.5370
Epoch 15/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.7800 - loss: 0.6429
Epoch 15: val_accuracy did not improve from 0.74766
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.7805 - loss: 0.6406 - val_accuracy: 0.6869 - val_loss: 0.9770
Epoch 16/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8237 - loss: 0.5233
Epoch 16: val_accuracy improved from 0.74766 to 0.78037, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8235 - loss: 0.5230 - val_accuracy: 0.7804 - val_loss: 0.6779
Epoch 17/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8465 - loss: 0.4621
Epoch 17: val_accuracy did not improve from 0.78037
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.8458 - loss: 0.4630 - val_accuracy: 0.7009 - val_loss: 1.0019
Epoch 18/50
118/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.8260 - loss: 0.4529
Epoch 18: val_accuracy improved from 0.78037 to 0.79439, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 14ms/step - accuracy: 0.8262 - loss: 0.4527 - val_accuracy: 0.7944 - val_loss: 0.6924
Epoch 19/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.8484 - loss: 0.4105
Epoch 19: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 14ms/step - accuracy: 0.8484 - loss: 0.4105 - val_accuracy: 0.7336 - val_loss: 0.9930
Epoch 20/50
118/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8506 - loss: 0.4197
Epoch 20: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8504 - loss: 0.4197 - val_accuracy: 0.6425 - val_loss: 2.0141
Epoch 21/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8629 - loss: 0.3667
Epoch 21: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.8627 - loss: 0.3673 - val_accuracy: 0.6379 - val_loss: 1.6969
Epoch 22/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8702 - loss: 0.3709
Epoch 22: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 12ms/step - accuracy: 0.8696 - loss: 0.3715 - val_accuracy: 0.7874 - val_loss: 0.8341
Epoch 23/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8916 - loss: 0.3083
Epoch 23: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.8912 - loss: 0.3090 - val_accuracy: 0.6121 - val_loss: 1.7701
Epoch 24/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8520 - loss: 0.3845
Epoch 24: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.8522 - loss: 0.3840 - val_accuracy: 0.7827 - val_loss: 0.8747
Epoch 25/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.8955 - loss: 0.2862
Epoch 25: val_accuracy did not improve from 0.79439
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 13ms/step - accuracy: 0.8954 - loss: 0.2866 - val_accuracy: 0.7593 - val_loss: 0.7854
Epoch 26/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.8957 - loss: 0.3053
Epoch 26: val_accuracy improved from 0.79439 to 0.86682, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 14ms/step - accuracy: 0.8956 - loss: 0.3054 - val_accuracy: 0.8668 - val_loss: 0.4277
Epoch 27/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8948 - loss: 0.2990
Epoch 27: val_accuracy did not improve from 0.86682
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - accuracy: 0.8945 - loss: 0.2987 - val_accuracy: 0.7150 - val_loss: 1.2410
Epoch 28/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9070 - loss: 0.2600
Epoch 28: val_accuracy did not improve from 0.86682
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.9065 - loss: 0.2609 - val_accuracy: 0.4042 - val_loss: 4.5967
Epoch 29/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9082 - loss: 0.2702
Epoch 29: val_accuracy did not improve from 0.86682
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.9081 - loss: 0.2702 - val_accuracy: 0.8107 - val_loss: 0.7366
Epoch 30/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.8984 - loss: 0.2612
Epoch 30: val_accuracy did not improve from 0.86682
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 12ms/step - accuracy: 0.8983 - loss: 0.2613 - val_accuracy: 0.7196 - val_loss: 1.2780
Epoch 31/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9113 - loss: 0.2497
Epoch 31: val_accuracy improved from 0.86682 to 0.87850, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 14ms/step - accuracy: 0.9112 - loss: 0.2497 - val_accuracy: 0.8785 - val_loss: 0.4082
Epoch 32/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9151 - loss: 0.2489
Epoch 32: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - accuracy: 0.9151 - loss: 0.2488 - val_accuracy: 0.8715 - val_loss: 0.4610
Epoch 33/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9104 - loss: 0.2418
Epoch 33: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.9103 - loss: 0.2420 - val_accuracy: 0.7500 - val_loss: 0.9120
Epoch 34/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9129 - loss: 0.2388
Epoch 34: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 11ms/step - accuracy: 0.9133 - loss: 0.2380 - val_accuracy: 0.5724 - val_loss: 3.1023
Epoch 35/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9223 - loss: 0.2346
Epoch 35: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.9223 - loss: 0.2336 - val_accuracy: 0.5771 - val_loss: 2.9692
Epoch 36/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9170 - loss: 0.2440
Epoch 36: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.9173 - loss: 0.2432 - val_accuracy: 0.4813 - val_loss: 3.5863
Epoch 37/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9290 - loss: 0.2080
Epoch 37: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - accuracy: 0.9289 - loss: 0.2079 - val_accuracy: 0.6098 - val_loss: 2.1593
Epoch 38/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9404 - loss: 0.1745
Epoch 38: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 13ms/step - accuracy: 0.9403 - loss: 0.1747 - val_accuracy: 0.5491 - val_loss: 3.5078
Epoch 39/50
120/121 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9414 - loss: 0.1815
Epoch 39: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 3s 13ms/step - accuracy: 0.9412 - loss: 0.1819 - val_accuracy: 0.3995 - val_loss: 4.7704
Epoch 40/50
116/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9240 - loss: 0.2437
Epoch 40: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - accuracy: 0.9243 - loss: 0.2419 - val_accuracy: 0.8551 - val_loss: 0.6812
Epoch 41/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9341 - loss: 0.1891
Epoch 41: val_accuracy did not improve from 0.87850
121/121 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.9341 - loss: 0.1891 - val_accuracy: 0.4509 - val_loss: 4.7781
Epoch 41: early stopping
Restoring model weights from the end of the best epoch: 31.
Plotting epochs vs accuracy¶
In [ ]:
plt.figure(figsize=(8,8))
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

Loading saved CNN model¶

This allows us to reuse the model saved in the previous step, avoiding the time and cost of re-training from scratch each time we run the notebook.

In [78]:
cnn_model = tf.keras.models.load_model('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_weighted_model.keras')

Model evaluation and performance¶

In [79]:
# evaluating the model
cnn_model.evaluate(X_test_norm, y_test_encoded)
15/15 ━━━━━━━━━━━━━━━━━━━━ 6s 319ms/step - accuracy: 0.8624 - loss: 0.4596
Out[79]:
[0.5718783140182495, 0.8378947377204895]
In [80]:
# Test prediction and accuracy
y_pred = cnn_model.predict(X_test_norm)
y_pred_classes = np.argmax(y_pred, axis=1)
y_test_classes = np.argmax(y_test_encoded, axis=1)
accuracy_score(y_test_classes, y_pred_classes)
15/15 ━━━━━━━━━━━━━━━━━━━━ 5s 327ms/step
Out[80]:
0.8378947368421052
In [81]:
# Confusion Matrix
cm = confusion_matrix(y_test_classes, y_pred_classes)
# keepdims=True so broadcasting divides each row by that class's support
cf_matrix_norm = cm / np.sum(cm, axis=1, keepdims=True)
plt.figure(figsize=(10,10))
sns.heatmap(cf_matrix_norm, annot=True, xticklabels=categories, yticklabels=categories, cmap='coolwarm')
plt.title('Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
[Heatmap: row-normalized confusion matrix for the 12 species]
In [82]:
class_report = classification_report(y_test_classes, y_pred_classes, target_names=categories)
print(class_report)
                           precision    recall  f1-score   support

              Black-grass       0.51      0.73      0.60        26
                 Charlock       0.84      0.92      0.88        39
                 Cleavers       0.81      0.86      0.83        29
         Common Chickweed       0.96      0.87      0.91        61
             Common wheat       0.79      0.86      0.83        22
                  Fat Hen       0.92      0.98      0.95        48
         Loose Silky-bent       0.83      0.77      0.80        65
                    Maize       0.95      0.91      0.93        22
        Scentless Mayweed       0.79      0.81      0.80        52
          Shepherds Purse       0.55      0.48      0.51        23
Small-flowered Cranesbill       0.98      0.88      0.93        50
               Sugar beet       0.91      0.84      0.88        38

                 accuracy                           0.84       475
                macro avg       0.82      0.83      0.82       475
             weighted avg       0.85      0.84      0.84       475

Observations¶

  • Above, we created a CNN model and used class weights to give more importance to the minority classes.
  • The model has 4 convolutional layers and 3 fully connected layers.
  • It has 182,508 parameters in total, of which 181,996 are trainable.
  • We stopped training if the validation loss did not decrease for 10 epochs, and saved the best model to a file.
  • With this configuration, training stopped after epoch 41, and the model weights were restored from the end of the best epoch, epoch 31.
  • The model achieved around 91% training accuracy and 87% validation accuracy at the end of epoch 31.
  • We used this model to evaluate and predict on the test dataset, where it achieved 83.8% accuracy.
  • Though the overall score is promising, the model slightly overfits: training accuracy is noticeably higher than validation and test accuracy.
  • From the confusion matrix and classification report, the recall for Shepherds Purse is just 48%, meaning more than half of its samples are misclassified.
  • The weighted average F1-score is 0.84.

Model Performance Improvement¶

Performance improvement with ReduceLROnPlateau and SGD¶

  • We will try to improve the model's performance by reducing the learning rate when the validation loss stops improving.
  • We will use ReduceLROnPlateau along with EarlyStopping.

        ReduceLROnPlateau: Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This callback monitors a quantity and if no improvement is seen for a 'patience' number of epochs, the learning rate is reduced.

  • We used the Adam optimizer in the previous step; however, ReduceLROnPlateau is more effective with SGD. Adam adapts per-parameter learning rates, whereas SGD uses a single global learning rate, so reducing that rate directly shrinks every update's step size.
  • We will use the same architecture as in the previous step.
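The reduction rule described above can be sketched in plain Python. This is an illustration of the schedule, not the Keras internals; it mirrors the factor=0.1 and min_lr=1e-6 settings used in the callback below.

```python
def reduce_lr_on_plateau(lr, factor=0.1, min_lr=1e-6):
    """Scale the learning rate down once learning stagnates, flooring at min_lr."""
    return max(lr * factor, min_lr)

# Starting from SGD's initial rate of 0.01, each plateau shrinks the step
# size tenfold (0.01 -> 0.001 -> 0.0001 -> ...) until the 1e-6 floor is hit
lr = 0.01
for plateau in range(3):
    lr = reduce_lr_on_plateau(lr)
    print(f"after plateau {plateau + 1}: lr = {lr:.0e}")
```

With SGD this rescaling acts on every parameter's step at once, which is why the learning-rate drops are clearly visible in the training log below.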

Model training¶

In [ ]:
# ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, patience=5, min_lr=1e-6)
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10, restore_best_weights=True)
mc = ModelCheckpoint('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
In [ ]:
# Reuse the saved CNN model's architecture: clone it (weights are freshly re-initialized) and train it with reduce_lr
reset_session()
# Clone the saved cnn_model
model = clone_model(cnn_model)
# Build the cloned model with the same input shape as the saved model
model.build(cnn_model.input_shape)
#optimizer
optimizer = SGD(learning_rate=0.01, momentum=0.9)
# compile
model.compile(optimizer = optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
In [ ]:
history = model.fit(X_train_norm, y_train_encoded, batch_size=32, epochs=50, validation_split=0.1, class_weight=class_weight_dict, callbacks=[es, mc, reduce_lr])
Epoch 1/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.0858 - loss: 2.5916
Epoch 1: val_accuracy improved from -inf to 0.12617, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 168s 1s/step - accuracy: 0.0858 - loss: 2.5909 - val_accuracy: 0.1262 - val_loss: 2.4765 - learning_rate: 0.0100
Epoch 2/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.1025 - loss: 2.4728
Epoch 2: val_accuracy improved from 0.12617 to 0.14953, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 159s 1s/step - accuracy: 0.1027 - loss: 2.4722 - val_accuracy: 0.1495 - val_loss: 2.3697 - learning_rate: 0.0100
Epoch 3/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.1800 - loss: 2.1614
Epoch 3: val_accuracy improved from 0.14953 to 0.35748, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 159s 1s/step - accuracy: 0.1802 - loss: 2.1608 - val_accuracy: 0.3575 - val_loss: 2.0012 - learning_rate: 0.0100
Epoch 4/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.2617 - loss: 1.9366
Epoch 4: val_accuracy did not improve from 0.35748
121/121 ━━━━━━━━━━━━━━━━━━━━ 199s 1s/step - accuracy: 0.2619 - loss: 1.9361 - val_accuracy: 0.2593 - val_loss: 2.1196 - learning_rate: 0.0100
Epoch 5/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3410 - loss: 1.7588
Epoch 5: val_accuracy did not improve from 0.35748
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.3413 - loss: 1.7582 - val_accuracy: 0.3551 - val_loss: 1.7253 - learning_rate: 0.0100
Epoch 6/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3981 - loss: 1.6208
Epoch 6: val_accuracy improved from 0.35748 to 0.41822, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 204s 1s/step - accuracy: 0.3983 - loss: 1.6204 - val_accuracy: 0.4182 - val_loss: 1.7061 - learning_rate: 0.0100
Epoch 7/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4395 - loss: 1.4883
Epoch 7: val_accuracy improved from 0.41822 to 0.50701, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 155s 1s/step - accuracy: 0.4397 - loss: 1.4880 - val_accuracy: 0.5070 - val_loss: 1.5595 - learning_rate: 0.0100
Epoch 8/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4719 - loss: 1.4092
Epoch 8: val_accuracy did not improve from 0.50701
121/121 ━━━━━━━━━━━━━━━━━━━━ 158s 1s/step - accuracy: 0.4719 - loss: 1.4089 - val_accuracy: 0.4907 - val_loss: 1.4382 - learning_rate: 0.0100
Epoch 9/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5134 - loss: 1.3060
Epoch 9: val_accuracy did not improve from 0.50701
121/121 ━━━━━━━━━━━━━━━━━━━━ 216s 1s/step - accuracy: 0.5135 - loss: 1.3059 - val_accuracy: 0.3621 - val_loss: 2.4334 - learning_rate: 0.0100
Epoch 10/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5713 - loss: 1.1898
Epoch 10: val_accuracy improved from 0.50701 to 0.55140, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 188s 1s/step - accuracy: 0.5714 - loss: 1.1897 - val_accuracy: 0.5514 - val_loss: 1.3687 - learning_rate: 0.0100
Epoch 11/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6021 - loss: 1.1258
Epoch 11: val_accuracy did not improve from 0.55140
121/121 ━━━━━━━━━━━━━━━━━━━━ 199s 1s/step - accuracy: 0.6021 - loss: 1.1258 - val_accuracy: 0.5444 - val_loss: 1.3503 - learning_rate: 0.0100
Epoch 12/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5956 - loss: 1.1230
Epoch 12: val_accuracy did not improve from 0.55140
121/121 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.5958 - loss: 1.1225 - val_accuracy: 0.4486 - val_loss: 1.5078 - learning_rate: 0.0100
Epoch 13/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6368 - loss: 0.9850
Epoch 13: val_accuracy did not improve from 0.55140
121/121 ━━━━━━━━━━━━━━━━━━━━ 203s 1s/step - accuracy: 0.6369 - loss: 0.9851 - val_accuracy: 0.5023 - val_loss: 2.0224 - learning_rate: 0.0100
Epoch 14/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6735 - loss: 0.9614
Epoch 14: val_accuracy did not improve from 0.55140
121/121 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.6736 - loss: 0.9609 - val_accuracy: 0.4112 - val_loss: 1.6038 - learning_rate: 0.0100
Epoch 15/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6410 - loss: 0.9854
Epoch 15: val_accuracy improved from 0.55140 to 0.71963, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.6412 - loss: 0.9849 - val_accuracy: 0.7196 - val_loss: 0.8812 - learning_rate: 0.0100
Epoch 16/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6967 - loss: 0.7940
Epoch 16: val_accuracy did not improve from 0.71963
121/121 ━━━━━━━━━━━━━━━━━━━━ 159s 1s/step - accuracy: 0.6968 - loss: 0.7939 - val_accuracy: 0.6822 - val_loss: 0.8639 - learning_rate: 0.0100
Epoch 17/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7517 - loss: 0.7274
Epoch 17: val_accuracy did not improve from 0.71963
121/121 ━━━━━━━━━━━━━━━━━━━━ 200s 1s/step - accuracy: 0.7516 - loss: 0.7273 - val_accuracy: 0.5514 - val_loss: 1.5773 - learning_rate: 0.0100
Epoch 18/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7364 - loss: 0.6618
Epoch 18: val_accuracy did not improve from 0.71963
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.7365 - loss: 0.6619 - val_accuracy: 0.7103 - val_loss: 0.9254 - learning_rate: 0.0100
Epoch 19/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7772 - loss: 0.6371
Epoch 19: val_accuracy did not improve from 0.71963
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.7771 - loss: 0.6373 - val_accuracy: 0.6636 - val_loss: 1.1437 - learning_rate: 0.0100
Epoch 20/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7737 - loss: 0.6401
Epoch 20: val_accuracy did not improve from 0.71963
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.7737 - loss: 0.6400 - val_accuracy: 0.6682 - val_loss: 1.2079 - learning_rate: 0.0100
Epoch 21/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8072 - loss: 0.5226
Epoch 21: val_accuracy improved from 0.71963 to 0.79206, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 204s 1s/step - accuracy: 0.8071 - loss: 0.5228 - val_accuracy: 0.7921 - val_loss: 0.7936 - learning_rate: 0.0100
Epoch 22/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8109 - loss: 0.5129
Epoch 22: val_accuracy did not improve from 0.79206
121/121 ━━━━━━━━━━━━━━━━━━━━ 209s 1s/step - accuracy: 0.8109 - loss: 0.5132 - val_accuracy: 0.6986 - val_loss: 1.0019 - learning_rate: 0.0100
Epoch 23/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8072 - loss: 0.5177
Epoch 23: val_accuracy did not improve from 0.79206
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.8073 - loss: 0.5175 - val_accuracy: 0.7570 - val_loss: 0.9990 - learning_rate: 0.0100
Epoch 24/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8321 - loss: 0.4381
Epoch 24: val_accuracy did not improve from 0.79206
121/121 ━━━━━━━━━━━━━━━━━━━━ 199s 1s/step - accuracy: 0.8320 - loss: 0.4384 - val_accuracy: 0.6729 - val_loss: 1.2788 - learning_rate: 0.0100
Epoch 25/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8202 - loss: 0.4593
Epoch 25: val_accuracy improved from 0.79206 to 0.80140, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.8204 - loss: 0.4590 - val_accuracy: 0.8014 - val_loss: 0.7177 - learning_rate: 0.0100
Epoch 26/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8509 - loss: 0.4051
Epoch 26: val_accuracy did not improve from 0.80140
121/121 ━━━━━━━━━━━━━━━━━━━━ 203s 1s/step - accuracy: 0.8509 - loss: 0.4051 - val_accuracy: 0.7757 - val_loss: 0.8207 - learning_rate: 0.0100
Epoch 27/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8731 - loss: 0.3687
Epoch 27: val_accuracy did not improve from 0.80140
121/121 ━━━━━━━━━━━━━━━━━━━━ 158s 1s/step - accuracy: 0.8731 - loss: 0.3687 - val_accuracy: 0.7453 - val_loss: 0.8860 - learning_rate: 0.0100
Epoch 28/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8513 - loss: 0.3757
Epoch 28: val_accuracy did not improve from 0.80140
121/121 ━━━━━━━━━━━━━━━━━━━━ 207s 1s/step - accuracy: 0.8512 - loss: 0.3760 - val_accuracy: 0.7827 - val_loss: 0.7496 - learning_rate: 0.0100
Epoch 29/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8449 - loss: 0.4293
Epoch 29: val_accuracy did not improve from 0.80140
121/121 ━━━━━━━━━━━━━━━━━━━━ 196s 1s/step - accuracy: 0.8450 - loss: 0.4289 - val_accuracy: 0.7664 - val_loss: 0.7873 - learning_rate: 0.0100
Epoch 30/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8774 - loss: 0.3189
Epoch 30: val_accuracy did not improve from 0.80140

Epoch 30: ReduceLROnPlateau reducing learning rate to 0.0009999999776482583.
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.8774 - loss: 0.3191 - val_accuracy: 0.7103 - val_loss: 1.1046 - learning_rate: 0.0100
Epoch 31/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.8724 - loss: 0.3660
Epoch 31: val_accuracy improved from 0.80140 to 0.85280, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 205s 1s/step - accuracy: 0.8725 - loss: 0.3657 - val_accuracy: 0.8528 - val_loss: 0.4595 - learning_rate: 1.0000e-03
Epoch 32/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9023 - loss: 0.2489
Epoch 32: val_accuracy improved from 0.85280 to 0.88318, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 155s 1s/step - accuracy: 0.9023 - loss: 0.2488 - val_accuracy: 0.8832 - val_loss: 0.4247 - learning_rate: 1.0000e-03
Epoch 33/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9105 - loss: 0.2434
Epoch 33: val_accuracy did not improve from 0.88318
121/121 ━━━━━━━━━━━━━━━━━━━━ 203s 1s/step - accuracy: 0.9105 - loss: 0.2433 - val_accuracy: 0.8832 - val_loss: 0.4104 - learning_rate: 1.0000e-03
Epoch 34/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9129 - loss: 0.2324
Epoch 34: val_accuracy did not improve from 0.88318
121/121 ━━━━━━━━━━━━━━━━━━━━ 200s 1s/step - accuracy: 0.9129 - loss: 0.2323 - val_accuracy: 0.8832 - val_loss: 0.4046 - learning_rate: 1.0000e-03
Epoch 35/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9237 - loss: 0.2002
Epoch 35: val_accuracy did not improve from 0.88318
121/121 ━━━━━━━━━━━━━━━━━━━━ 160s 1s/step - accuracy: 0.9237 - loss: 0.2002 - val_accuracy: 0.8668 - val_loss: 0.4097 - learning_rate: 1.0000e-03
Epoch 36/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9222 - loss: 0.2154
Epoch 36: val_accuracy did not improve from 0.88318
121/121 ━━━━━━━━━━━━━━━━━━━━ 198s 1s/step - accuracy: 0.9222 - loss: 0.2153 - val_accuracy: 0.8832 - val_loss: 0.3963 - learning_rate: 1.0000e-03
Epoch 37/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9279 - loss: 0.1883
Epoch 37: val_accuracy did not improve from 0.88318
121/121 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.9279 - loss: 0.1883 - val_accuracy: 0.8785 - val_loss: 0.4174 - learning_rate: 1.0000e-03
Epoch 38/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9235 - loss: 0.1966
Epoch 38: val_accuracy improved from 0.88318 to 0.88551, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras
121/121 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.9235 - loss: 0.1965 - val_accuracy: 0.8855 - val_loss: 0.4338 - learning_rate: 1.0000e-03
Epoch 39/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9317 - loss: 0.1777
Epoch 39: val_accuracy did not improve from 0.88551
121/121 ━━━━━━━━━━━━━━━━━━━━ 160s 1s/step - accuracy: 0.9317 - loss: 0.1777 - val_accuracy: 0.8855 - val_loss: 0.4189 - learning_rate: 1.0000e-03
Epoch 40/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9257 - loss: 0.1851
Epoch 40: val_accuracy did not improve from 0.88551
121/121 ━━━━━━━━━━━━━━━━━━━━ 204s 1s/step - accuracy: 0.9257 - loss: 0.1850 - val_accuracy: 0.8785 - val_loss: 0.4226 - learning_rate: 1.0000e-03
Epoch 41/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9394 - loss: 0.1647
Epoch 41: val_accuracy did not improve from 0.88551

Epoch 41: ReduceLROnPlateau reducing learning rate to 9.999999310821295e-05.
121/121 ━━━━━━━━━━━━━━━━━━━━ 200s 1s/step - accuracy: 0.9393 - loss: 0.1648 - val_accuracy: 0.8855 - val_loss: 0.4172 - learning_rate: 1.0000e-03
Epoch 42/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9364 - loss: 0.1636
Epoch 42: val_accuracy did not improve from 0.88551
121/121 ━━━━━━━━━━━━━━━━━━━━ 200s 1s/step - accuracy: 0.9365 - loss: 0.1637 - val_accuracy: 0.8832 - val_loss: 0.4074 - learning_rate: 1.0000e-04
Epoch 43/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9366 - loss: 0.1715
Epoch 43: val_accuracy did not improve from 0.88551
121/121 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.9366 - loss: 0.1715 - val_accuracy: 0.8808 - val_loss: 0.4092 - learning_rate: 1.0000e-04
Epoch 44/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9302 - loss: 0.1725
Epoch 44: val_accuracy did not improve from 0.88551
121/121 ━━━━━━━━━━━━━━━━━━━━ 201s 1s/step - accuracy: 0.9302 - loss: 0.1725 - val_accuracy: 0.8785 - val_loss: 0.4112 - learning_rate: 1.0000e-04
Epoch 45/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9315 - loss: 0.1683
Epoch 45: val_accuracy did not improve from 0.88551
121/121 ━━━━━━━━━━━━━━━━━━━━ 204s 1s/step - accuracy: 0.9315 - loss: 0.1684 - val_accuracy: 0.8808 - val_loss: 0.4093 - learning_rate: 1.0000e-04
Epoch 46/50
121/121 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.9348 - loss: 0.1667
Epoch 46: val_accuracy did not improve from 0.88551

Epoch 46: ReduceLROnPlateau reducing learning rate to 9.999999019782991e-06.
121/121 ━━━━━━━━━━━━━━━━━━━━ 158s 1s/step - accuracy: 0.9348 - loss: 0.1667 - val_accuracy: 0.8808 - val_loss: 0.4090 - learning_rate: 1.0000e-04
Epoch 46: early stopping
Restoring model weights from the end of the best epoch: 36.
Plotting epochs vs accuracy¶
In [ ]:
plt.figure(figsize=(8,8))
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
[Plot: model accuracy vs. epoch for training and validation]

Loading the saved model¶

In [83]:
cnn_model_sgd = tf.keras.models.load_model('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras')

Model evaluation and performance¶

In [84]:
# Model evaluation
cnn_model_sgd.evaluate(X_test_norm, y_test_encoded)
15/15 ━━━━━━━━━━━━━━━━━━━━ 3s 193ms/step - accuracy: 0.9169 - loss: 0.3877
Out[84]:
[0.5075504183769226, 0.8926315903663635]
In [85]:
# Test prediction and accuracy
y_pred = cnn_model_sgd.predict(X_test_norm)
y_pred_classes = np.argmax(y_pred, axis=1)
y_test_classes = np.argmax(y_test_encoded, axis=1)
accuracy_score(y_test_classes, y_pred_classes)
15/15 ━━━━━━━━━━━━━━━━━━━━ 3s 198ms/step
Out[85]:
0.8926315789473684
In [86]:
# Confusion matrix
cm = confusion_matrix(y_test_classes, y_pred_classes)
# keepdims=True so broadcasting divides each row by that class's support
cf_matrix_norm = cm / np.sum(cm, axis=1, keepdims=True)
plt.figure(figsize=(10,10))
sns.heatmap(cf_matrix_norm, annot=True, xticklabels=categories, yticklabels=categories, cmap='coolwarm')
plt.title('Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
[Heatmap: row-normalized confusion matrix for the 12 species]
In [87]:
# Classification report
class_report = classification_report(y_test_classes, y_pred_classes, target_names=categories)
print(class_report)
                           precision    recall  f1-score   support

              Black-grass       0.57      0.65      0.61        26
                 Charlock       0.97      0.97      0.97        39
                 Cleavers       0.93      0.93      0.93        29
         Common Chickweed       0.95      0.97      0.96        61
             Common wheat       0.90      0.86      0.88        22
                  Fat Hen       1.00      0.94      0.97        48
         Loose Silky-bent       0.85      0.85      0.85        65
                    Maize       0.85      1.00      0.92        22
        Scentless Mayweed       0.89      0.90      0.90        52
          Shepherds Purse       0.76      0.70      0.73        23
Small-flowered Cranesbill       0.90      0.90      0.90        50
               Sugar beet       1.00      0.89      0.94        38

                 accuracy                           0.89       475
                macro avg       0.88      0.88      0.88       475
             weighted avg       0.90      0.89      0.89       475

Observations¶

  • After tuning with the SGD optimizer and ReduceLROnPlateau, the model's performance improved on both the training and validation sets.
  • The best model has around 92% accuracy on the training set and around 88% on the validation set.
  • The model performs well on the test dataset with around 89% accuracy, so it is generalizing well.
  • From the confusion matrix and classification report, the recall for Black-grass dropped from 73% to 65%, but the recall for Shepherds Purse improved from 48% to 70%.
  • The weighted average F1-score is 0.89.

Data Augmentation¶

  • The previous models used only class weights to balance the dataset, which gave higher importance to the under-represented classes but did not add any synthetic data.
  • We can combine data augmentation with class weights: augmentation adds synthetic variations of the images, while class weights give more importance to the under-represented classes.
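For reference, the class_weight_dict passed to fit() was computed earlier in the notebook; the usual "balanced" heuristic weights each class by n_samples / (n_classes * count). A minimal sketch with hypothetical labels (the exact computation used earlier may differ):

```python
import numpy as np

# Hypothetical integer labels: class 0 is over-represented, class 2 is rare
labels = np.array([0, 0, 0, 0, 1, 1, 2])

classes, counts = np.unique(labels, return_counts=True)
# "balanced" heuristic: n_samples / (n_classes * count_per_class)
weights = len(labels) / (len(classes) * counts)
class_weight_example = dict(zip(classes.tolist(), weights.tolist()))

# The rare class receives the largest weight, so its errors cost more in the loss
print(class_weight_example)
```

During training, Keras multiplies each sample's loss by the weight of its class, so the under-represented species still contribute meaningfully to the gradient even after augmentation.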

Augmentation configuration¶

In [88]:
reset_session()
In [89]:
# Define data augmentation settings
# We will reuse the normalized dataset, so we are not going to rescale the images
aug_settings = ImageDataGenerator(
    horizontal_flip = True,
    vertical_flip = False,
    rotation_range = 20,
    shear_range = 0.1,
    zoom_range = 0.1,
    width_shift_range = 0.1,
    height_shift_range = 0.1,
)

We have to split the training dataset ourselves because validation_split can't be used inside fit() when training with generators. We will fit the model with train_generator and val_generator instead of the X_train_norm and y_train_encoded arrays.

In [90]:
# Split data into training and validation sets
X_train_aug, X_val_aug, y_train_aug, y_val_aug = train_test_split(
    X_train_norm,
    y_train_encoded,
    test_size=0.1,  # 10% for validation
    random_state = random_state,
    stratify = y_train_encoded
)

# Creates train augmented generator
train_generator = aug_settings.flow(
    X_train_aug,
    y_train_aug,
    batch_size=32,
    shuffle=True)
# Validation generator (note: this also applies augmentation to the validation
# images; an un-augmented generator is more conventional for validation)
val_generator = aug_settings.flow(
    X_val_aug,
    y_val_aug,
    batch_size=32,
    shuffle=False)

Model Training¶

In [91]:
# ReduceLROnPlateau, EarlyStopping and ModelCheckpoint
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, patience=5, min_lr=1e-6)
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10, restore_best_weights=True)
mc = ModelCheckpoint('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
In [92]:
# Clone the saved cnn_model_sgd
model = clone_model(cnn_model_sgd)
# Build the cloned model with the same input shape as the saved model
model.build(cnn_model_sgd.input_shape)
#optimizer
optimizer = SGD(learning_rate=0.01, momentum=0.9)
# compile
model.compile(optimizer = optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
In [93]:
history = model.fit(
    train_generator,
    epochs=50,
    steps_per_epoch=len(X_train_aug)//32,
    validation_data = val_generator,
    validation_steps = len(X_val_aug)//32,
    class_weight=class_weight_dict,
    callbacks=[es, mc, reduce_lr]
)
Epoch 1/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.0957 - loss: 2.5824
Epoch 1: val_accuracy improved from -inf to 0.10337, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 147s 1s/step - accuracy: 0.0958 - loss: 2.5818 - val_accuracy: 0.1034 - val_loss: 2.4805 - learning_rate: 0.0100
Epoch 2/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 2:09 1s/step - accuracy: 0.1250 - loss: 2.2314
Epoch 2: val_accuracy improved from 0.10337 to 0.16667, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 5s 36ms/step - accuracy: 0.1250 - loss: 2.2314 - val_accuracy: 0.1667 - val_loss: 2.4530 - learning_rate: 0.0100
Epoch 3/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.1043 - loss: 2.4698
Epoch 3: val_accuracy did not improve from 0.16667
120/120 ━━━━━━━━━━━━━━━━━━━━ 194s 1s/step - accuracy: 0.1044 - loss: 2.4695 - val_accuracy: 0.0505 - val_loss: 2.4543 - learning_rate: 0.0100
Epoch 4/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:50 928ms/step - accuracy: 0.1250 - loss: 2.1805
Epoch 4: val_accuracy did not improve from 0.16667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 752us/step - accuracy: 0.1250 - loss: 2.1805 - val_accuracy: 0.0833 - val_loss: 2.4743 - learning_rate: 0.0100
Epoch 5/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.1561 - loss: 2.2209
Epoch 5: val_accuracy did not improve from 0.16667
120/120 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.1562 - loss: 2.2203 - val_accuracy: 0.1226 - val_loss: 2.1587 - learning_rate: 0.0100
Epoch 6/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:53 950ms/step - accuracy: 0.3750 - loss: 1.9730
Epoch 6: val_accuracy did not improve from 0.16667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 742us/step - accuracy: 0.3750 - loss: 1.9730 - val_accuracy: 0.1667 - val_loss: 2.1088 - learning_rate: 0.0100
Epoch 7/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.2232 - loss: 2.0221
Epoch 7: val_accuracy did not improve from 0.16667
120/120 ━━━━━━━━━━━━━━━━━━━━ 140s 1s/step - accuracy: 0.2234 - loss: 2.0218 - val_accuracy: 0.1298 - val_loss: 2.1430 - learning_rate: 0.0100
Epoch 8/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 3:07 2s/step - accuracy: 0.3750 - loss: 1.6719
Epoch 8: val_accuracy did not improve from 0.16667
120/120 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.3750 - loss: 1.6719 - val_accuracy: 0.0000e+00 - val_loss: 2.2259 - learning_rate: 0.0100
Epoch 9/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3190 - loss: 1.8030
Epoch 9: val_accuracy improved from 0.16667 to 0.19952, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.3190 - loss: 1.8029 - val_accuracy: 0.1995 - val_loss: 2.0950 - learning_rate: 0.0100
Epoch 10/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 2:35 1s/step - accuracy: 0.2188 - loss: 1.7166
Epoch 10: val_accuracy did not improve from 0.19952
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.2188 - loss: 1.7166 - val_accuracy: 0.1667 - val_loss: 1.7379 - learning_rate: 0.0100
Epoch 11/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3481 - loss: 1.7062
Epoch 11: val_accuracy improved from 0.19952 to 0.34135, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 145s 1s/step - accuracy: 0.3484 - loss: 1.7060 - val_accuracy: 0.3413 - val_loss: 1.8474 - learning_rate: 0.0100
Epoch 12/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:57 987ms/step - accuracy: 0.4062 - loss: 1.8183
Epoch 12: val_accuracy improved from 0.34135 to 0.58333, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4062 - loss: 1.8183 - val_accuracy: 0.5833 - val_loss: 1.4858 - learning_rate: 0.0100
Epoch 13/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3938 - loss: 1.6537
Epoch 13: val_accuracy did not improve from 0.58333
120/120 ━━━━━━━━━━━━━━━━━━━━ 198s 1s/step - accuracy: 0.3937 - loss: 1.6536 - val_accuracy: 0.3510 - val_loss: 2.2470 - learning_rate: 0.0100
Epoch 14/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 2:14 1s/step - accuracy: 0.3750 - loss: 1.8222
Epoch 14: val_accuracy did not improve from 0.58333
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.3750 - loss: 1.8222 - val_accuracy: 0.5833 - val_loss: 1.5988 - learning_rate: 0.0100
Epoch 15/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4432 - loss: 1.5709
Epoch 15: val_accuracy did not improve from 0.58333
120/120 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.4432 - loss: 1.5706 - val_accuracy: 0.1755 - val_loss: 3.3360 - learning_rate: 0.0100
Epoch 16/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 3:02 2s/step - accuracy: 0.3125 - loss: 1.7872
Epoch 16: val_accuracy did not improve from 0.58333
120/120 ━━━━━━━━━━━━━━━━━━━━ 6s 41ms/step - accuracy: 0.3125 - loss: 1.7872 - val_accuracy: 0.0833 - val_loss: 3.1351 - learning_rate: 0.0100
Epoch 17/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4820 - loss: 1.4503
Epoch 17: val_accuracy did not improve from 0.58333

Epoch 17: ReduceLROnPlateau reducing learning rate to 0.0009999999776482583.
120/120 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.4821 - loss: 1.4503 - val_accuracy: 0.3654 - val_loss: 1.7746 - learning_rate: 0.0100
Epoch 18/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:51 939ms/step - accuracy: 0.4375 - loss: 1.5296
Epoch 18: val_accuracy did not improve from 0.58333
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 802us/step - accuracy: 0.4375 - loss: 1.5296 - val_accuracy: 0.5000 - val_loss: 1.3048 - learning_rate: 1.0000e-03
Epoch 19/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4577 - loss: 1.4494
Epoch 19: val_accuracy improved from 0.58333 to 0.63702, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 200s 1s/step - accuracy: 0.4580 - loss: 1.4488 - val_accuracy: 0.6370 - val_loss: 1.1355 - learning_rate: 1.0000e-03
Epoch 20/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:53 958ms/step - accuracy: 0.4688 - loss: 1.7354
Epoch 20: val_accuracy improved from 0.63702 to 0.83333, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.4688 - loss: 1.7354 - val_accuracy: 0.8333 - val_loss: 0.8676 - learning_rate: 1.0000e-03
Epoch 21/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5378 - loss: 1.2904
Epoch 21: val_accuracy did not improve from 0.83333
120/120 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.5379 - loss: 1.2904 - val_accuracy: 0.6587 - val_loss: 1.0644 - learning_rate: 1.0000e-03
Epoch 22/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 2:24 1s/step - accuracy: 0.5625 - loss: 0.9808
Epoch 22: val_accuracy improved from 0.83333 to 0.91667, saving model to /content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras
120/120 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - accuracy: 0.5625 - loss: 0.9808 - val_accuracy: 0.9167 - val_loss: 0.7902 - learning_rate: 1.0000e-03
Epoch 23/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5535 - loss: 1.2949
Epoch 23: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 199s 1s/step - accuracy: 0.5536 - loss: 1.2949 - val_accuracy: 0.5889 - val_loss: 1.1672 - learning_rate: 1.0000e-03
Epoch 24/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:54 959ms/step - accuracy: 0.6562 - loss: 1.1938
Epoch 24: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.6562 - loss: 1.1938 - val_accuracy: 0.8333 - val_loss: 0.7754 - learning_rate: 1.0000e-03
Epoch 25/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5684 - loss: 1.2429
Epoch 25: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 202s 1s/step - accuracy: 0.5684 - loss: 1.2428 - val_accuracy: 0.6418 - val_loss: 1.0939 - learning_rate: 1.0000e-03
Epoch 26/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:53 952ms/step - accuracy: 0.5625 - loss: 1.0305
Epoch 26: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 740us/step - accuracy: 0.5625 - loss: 1.0305 - val_accuracy: 0.9167 - val_loss: 0.7818 - learning_rate: 1.0000e-03
Epoch 27/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5697 - loss: 1.2399
Epoch 27: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 199s 1s/step - accuracy: 0.5698 - loss: 1.2395 - val_accuracy: 0.6466 - val_loss: 1.1010 - learning_rate: 1.0000e-03
Epoch 28/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:55 972ms/step - accuracy: 0.4688 - loss: 1.4970
Epoch 28: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 818us/step - accuracy: 0.4688 - loss: 1.4970 - val_accuracy: 0.9167 - val_loss: 0.6294 - learning_rate: 1.0000e-03
Epoch 29/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5907 - loss: 1.1721
Epoch 29: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - accuracy: 0.5906 - loss: 1.1722 - val_accuracy: 0.6611 - val_loss: 0.9897 - learning_rate: 1.0000e-03
Epoch 30/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:53 957ms/step - accuracy: 0.6250 - loss: 1.2735
Epoch 30: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 699us/step - accuracy: 0.6250 - loss: 1.2735 - val_accuracy: 0.8333 - val_loss: 0.8007 - learning_rate: 1.0000e-03
Epoch 31/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5708 - loss: 1.2129
Epoch 31: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.5709 - loss: 1.2127 - val_accuracy: 0.6418 - val_loss: 1.0644 - learning_rate: 1.0000e-03
Epoch 32/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:56 976ms/step - accuracy: 0.6562 - loss: 1.1153
Epoch 32: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 677us/step - accuracy: 0.6562 - loss: 1.1153 - val_accuracy: 0.8333 - val_loss: 0.8109 - learning_rate: 1.0000e-03
Epoch 33/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5988 - loss: 1.1558
Epoch 33: val_accuracy did not improve from 0.91667

Epoch 33: ReduceLROnPlateau reducing learning rate to 9.999999310821295e-05.
120/120 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.5986 - loss: 1.1560 - val_accuracy: 0.7043 - val_loss: 0.9720 - learning_rate: 1.0000e-03
Epoch 34/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 3:10 2s/step - accuracy: 0.6250 - loss: 1.1716
Epoch 34: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.6250 - loss: 1.1716 - val_accuracy: 0.9167 - val_loss: 0.5967 - learning_rate: 1.0000e-04
Epoch 35/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6207 - loss: 1.1228
Epoch 35: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 144s 1s/step - accuracy: 0.6207 - loss: 1.1230 - val_accuracy: 0.6947 - val_loss: 0.9382 - learning_rate: 1.0000e-04
Epoch 36/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:57 991ms/step - accuracy: 0.5000 - loss: 1.3786
Epoch 36: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 720us/step - accuracy: 0.5000 - loss: 1.3786 - val_accuracy: 0.9167 - val_loss: 0.5703 - learning_rate: 1.0000e-04
Epoch 37/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6231 - loss: 1.1243
Epoch 37: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 199s 1s/step - accuracy: 0.6231 - loss: 1.1244 - val_accuracy: 0.6923 - val_loss: 0.9643 - learning_rate: 1.0000e-04
Epoch 38/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:57 984ms/step - accuracy: 0.5938 - loss: 1.1645
Epoch 38: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 777us/step - accuracy: 0.5938 - loss: 1.1645 - val_accuracy: 0.9167 - val_loss: 0.5992 - learning_rate: 1.0000e-04
Epoch 39/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6055 - loss: 1.1450
Epoch 39: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.6054 - loss: 1.1450 - val_accuracy: 0.6923 - val_loss: 0.9573 - learning_rate: 1.0000e-04
Epoch 40/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 3:10 2s/step - accuracy: 0.6562 - loss: 1.1219
Epoch 40: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.6562 - loss: 1.1219 - val_accuracy: 0.9167 - val_loss: 0.5876 - learning_rate: 1.0000e-04
Epoch 41/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6060 - loss: 1.1250
Epoch 41: val_accuracy did not improve from 0.91667

Epoch 41: ReduceLROnPlateau reducing learning rate to 9.999999019782991e-06.
120/120 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - accuracy: 0.6061 - loss: 1.1251 - val_accuracy: 0.6899 - val_loss: 0.9446 - learning_rate: 1.0000e-04
Epoch 42/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:57 987ms/step - accuracy: 0.5938 - loss: 1.2995
Epoch 42: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 707us/step - accuracy: 0.5938 - loss: 1.2995 - val_accuracy: 0.9167 - val_loss: 0.6521 - learning_rate: 1.0000e-05
Epoch 43/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5854 - loss: 1.1776
Epoch 43: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.5855 - loss: 1.1773 - val_accuracy: 0.6995 - val_loss: 0.9294 - learning_rate: 1.0000e-05
Epoch 44/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 1:54 965ms/step - accuracy: 0.6250 - loss: 0.9305
Epoch 44: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 1s 665us/step - accuracy: 0.6250 - loss: 0.9305 - val_accuracy: 0.9167 - val_loss: 0.6413 - learning_rate: 1.0000e-05
Epoch 45/50
120/120 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6006 - loss: 1.1616
Epoch 45: val_accuracy did not improve from 0.91667
120/120 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.6007 - loss: 1.1613 - val_accuracy: 0.6875 - val_loss: 0.9508 - learning_rate: 1.0000e-05
Epoch 46/50
  1/120 ━━━━━━━━━━━━━━━━━━━━ 2:58 1s/step - accuracy: 0.6250 - loss: 1.3666
Epoch 46: val_accuracy did not improve from 0.91667

Epoch 46: ReduceLROnPlateau reducing learning rate to 1e-06.
120/120 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - accuracy: 0.6250 - loss: 1.3666 - val_accuracy: 0.9167 - val_loss: 0.6030 - learning_rate: 1.0000e-05
Epoch 46: early stopping
Restoring model weights from the end of the best epoch: 36.
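The log above shows the learning rate stepping down 0.01 → 1e-3 → 1e-4 → 1e-5 → 1e-6 whenever `val_accuracy` stalls. The following is a pure-Python illustration of that plateau logic, not the exact Keras internals; the `factor=0.1` and the `patience` value are assumptions inferred from the log.

```python
# Sketch of ReduceLROnPlateau-style behavior: multiply the learning rate by
# `factor` whenever the monitored metric fails to improve for `patience`
# consecutive epochs (hypothetical helper, not part of this notebook).
def simulate_reduce_lr(val_metric_history, lr=1e-2, factor=0.1,
                       patience=8, min_lr=1e-6):
    """Return the learning rate in effect at each epoch."""
    best = float("-inf")
    wait = 0
    lrs = []
    for metric in val_metric_history:
        lrs.append(lr)
        if metric > best:
            best = metric
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)
                wait = 0
    return lrs

# A flat validation metric triggers a reduction every `patience` epochs:
lrs = simulate_reduce_lr([0.5] * 10, patience=4)
```

With a flat metric and `patience=4`, the first five epochs run at 0.01, the next four at 0.001, and the tenth at 0.0001, mirroring the staircase pattern in the log.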
Plotting epochs vs accuracy¶
In [96]:
plt.figure(figsize=(8,8))
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.show()

Loading the saved model¶

In [97]:
cnn_model_aug = tf.keras.models.load_model('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd_aug.keras')

Model evaluation and performance¶

In [101]:
# model evaluation
cnn_model_aug.evaluate(X_test_norm, y_test_encoded)
15/15 ━━━━━━━━━━━━━━━━━━━━ 4s 256ms/step - accuracy: 0.6759 - loss: 1.0138
Out[101]:
[1.0811728239059448, 0.6294736862182617]
In [102]:
# prediction
y_pred = cnn_model_aug.predict(X_test_norm)
y_pred_classes = np.argmax(y_pred, axis=1)
y_test_classes = np.argmax(y_test_encoded, axis=1)
accuracy_score(y_test_classes, y_pred_classes)
15/15 ━━━━━━━━━━━━━━━━━━━━ 5s 344ms/step
Out[102]:
0.6294736842105263
In [103]:
# confusion matrix
cm = confusion_matrix(y_test_classes, y_pred_classes)
cf_matrix_norm = cm / np.sum(cm, axis=1, keepdims=True)  # normalize each row (true class) to proportions
plt.figure(figsize=(10,10))
sns.heatmap(cf_matrix_norm, annot=True, xticklabels=categories, yticklabels=categories, cmap='coolwarm')
plt.title('Confusion Matrix')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
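Beyond eyeballing the heatmap, the largest off-diagonal entries of the confusion matrix identify the class pairs the model mixes up most often. A small helper sketch (a hypothetical utility, not part of the notebook's pipeline):

```python
import numpy as np

def top_confusions(cm, labels, k=3):
    """Return the k largest off-diagonal confusion-matrix entries
    as (true_label, predicted_label, count) triples."""
    off = np.asarray(cm).copy()
    np.fill_diagonal(off, 0)                      # ignore correct predictions
    flat = np.argsort(off, axis=None)[::-1][:k]   # indices of largest counts
    rows, cols = np.unravel_index(flat, off.shape)
    return [(labels[r], labels[c], int(off[r, c])) for r, c in zip(rows, cols)]
```

For example, applied to `cm` and `categories` from the cell above, this would surface pairs such as Black-grass being predicted as Loose Silky-bent.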
In [104]:
# classification report
class_report = classification_report(y_test_classes, y_pred_classes, target_names=categories)
print(class_report)
                           precision    recall  f1-score   support

              Black-grass       0.34      0.50      0.41        26
                 Charlock       0.62      0.92      0.74        39
                 Cleavers       0.78      0.72      0.75        29
         Common Chickweed       0.76      0.85      0.81        61
             Common wheat       0.57      0.73      0.64        22
                  Fat Hen       0.62      0.21      0.31        48
         Loose Silky-bent       0.70      0.65      0.67        65
                    Maize       0.64      0.32      0.42        22
        Scentless Mayweed       0.55      0.52      0.53        52
          Shepherds Purse       0.29      0.26      0.27        23
Small-flowered Cranesbill       0.80      0.82      0.81        50
               Sugar beet       0.58      0.74      0.65        38

                 accuracy                           0.63       475
                macro avg       0.61      0.60      0.59       475
             weighted avg       0.64      0.63      0.61       475
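The recall column of the report can be recovered directly from the confusion matrix: recall for class i is the diagonal entry divided by the i-th row total. A minimal sketch on a toy 2x2 matrix:

```python
import numpy as np

def per_class_recall(cm):
    """Recall per class: true positives (diagonal) over row totals."""
    cm = np.asarray(cm, dtype=float)
    return np.diag(cm) / cm.sum(axis=1)

# Toy example: class 0 has 8/10 correct, class 1 has 7/10 correct.
cm = np.array([[8, 2],
               [3, 7]])
print(per_class_recall(cm))  # → [0.8 0.7]
```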

Observations¶

  • Above we tried data augmentation with class weights, but performance did not improve. The model reached only around 56% training accuracy while validation accuracy peaked near 91%, which suggests underfitting during training.
  • The model achieved only around 62% accuracy on the test dataset, far below the validation accuracy.
  • The gap between validation and test performance suggests one or more of the following:
    • The validation set is too small
    • There are overlapping similarities or data leakage between the training and validation sets
    • The validation set is biased (not representative of the overall class distribution)
  • We could have built another model with a different data split; however, we already achieved good results without data augmentation, so we will choose our final model from the first two models we built.
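One way to reduce the risk of a biased or unrepresentative validation split (shown here as an illustration, not the split actually used in this notebook) is stratified sampling, which preserves per-class proportions in both partitions:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in data: 120 samples, 4 balanced classes (30 each).
X = np.arange(120).reshape(-1, 1)
y = np.repeat(np.arange(4), 30)

# stratify=y keeps the class proportions identical in train and validation.
X_tr, X_val, y_tr, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

print(np.bincount(y_val))  # → [6 6 6 6]
```

Each class contributes exactly its overall fraction to the validation set, so no class can be accidentally over- or under-represented.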

Final Model¶

  • We built 3 models. All three share the same architecture, with 4 convolutional layers and 3 fully connected layers, and have 181,996 trainable parameters.

  • Model 1: We used class_weights to give higher importance to the minority classes, with the Adam optimizer. This model achieved around 91% accuracy on the training data, 87% on the validation dataset, and around 84% on the test dataset. Though the model performs well, it shows slight overfitting.

  • Model 2: We used the SGD optimizer with ReduceLROnPlateau to achieve better performance. We observed 92% accuracy on the training dataset, around 88% on the validation dataset, and around 89% on the test dataset. This model performed well on all three datasets and generalizes well.

  • Model 3: We used dynamic data augmentation along with class weights, keeping the optimizer and architecture the same as the previous model. We observed only 56% accuracy on the training dataset, 91% on the validation dataset, and 62% on the test dataset. This suggests underfitting during training and probably an issue with the validation split.

  • Final Model: Model 2 is our final chosen model, as it has very good performance and generalizes well.

Loading the final model and making a random prediction¶

In [28]:
final_model = tf.keras.models.load_model('/content/drive/MyDrive/AI_ML_PGP/Projects/PlantSeedlingsClassification/cnn_model_sgd.keras')
In [30]:
y_pred = final_model.predict(X_test_norm)
y_pred_classes = np.argmax(y_pred, axis=1)
y_test_classes = np.argmax(y_test_encoded, axis=1)
accuracy_score(y_test_classes, y_pred_classes)
15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
Out[30]:
0.8926315789473684
In [33]:
# get a random image index
random_index = np.random.randint(0, len(X_test_norm))
random_index
Out[33]:
348
In [34]:
# prediction
result = final_model.predict(np.expand_dims(X_test_norm[random_index], axis=0))
result
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 212ms/step
Out[34]:
array([[3.9270812e-01, 1.7879278e-05, 4.5150420e-05, 1.0130113e-04,
        1.2365825e-03, 4.7762040e-03, 6.0056323e-01, 4.8438801e-06,
        4.6698211e-04, 7.8904714e-06, 4.9278548e-05, 2.2528231e-05]],
      dtype=float32)
In [39]:
i = np.argmax(result)
y_test_classes[random_index], categories[i]
Out[39]:
(6, 'Loose Silky-bent')

Visualizing the prediction¶

In [55]:
actual = y_test.iloc[random_index]['Label']
In [59]:
plt.imshow(X_test[random_index])
plt.title(f'Predicted: {categories[i]} \nActual:{actual} \nProbability: {result[0][i]:.2f}')
plt.axis('off')
plt.show()

Actionable Insights and Business Recommendations¶

  • Our final model achieves 89% accuracy on the test dataset.
  • We noticed the recall for Black-grass and Shepherds Purse is low, at 65% and 70% respectively.
  • This could be due to the smaller sample sizes for these classes and image quality.
  • We noticed the model wrongly predicts Loose Silky-bent as Black-grass.
  • Increasing the sample size of Black-grass and improving the image quality would help improve its recall score.
  • The same applies to Shepherds Purse, whose sample size is also small. We noticed Scentless Mayweed and Small-flowered Cranesbill are wrongly predicted as Shepherds Purse.
  • Here, we did not fine-tune a pre-trained model such as VGG-16, ResNet, or EfficientNet. If the recommended model does not perform well on real-world data, we recommend fine-tuning a pre-trained model known to perform well on this kind of dataset.
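A hedged sketch of the fine-tuning approach suggested above, not part of this notebook's pipeline. It assumes 64x64 RGB inputs and 12 classes; in practice `weights="imagenet"` would be used (which triggers a one-time download), the frozen base would be trained first, and the top layers of the base unfrozen afterwards for a second, low-learning-rate stage.

```python
import tensorflow as tf

def build_vgg16_finetune(input_shape=(64, 64, 3), n_classes=12):
    """Hypothetical transfer-learning model: frozen VGG16 base + small head.
    weights=None keeps this sketch runnable offline; use "imagenet" in practice."""
    base = tf.keras.applications.VGG16(
        include_top=False, weights=None, input_shape=input_shape
    )
    base.trainable = False  # freeze the convolutional base for the first stage
    model = tf.keras.Sequential([
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(n_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_vgg16_finetune()
```

After the head converges, one would typically set `base.trainable = True` for the last few blocks and recompile with a much smaller learning rate before continuing training.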